Converting the State Dict
The training script (train.py) doesn't support any fancy saving/checkpointing methods, but it does optionally save the model right at the end of training into a safetensors file. In this notebook we'll show how to load these saved weights for downstream evaluation and usage. This should hopefully become unnecessary as frameworks integrate the changes needed to make FSDP+QLoRA work natively.
As an example, let's look at a model trained with the following command (using default settings for LoRA rank etc.):
python train.py --save_model True --train_type qlora --output_dir qlora_output
We'll load the saved state_dict, and then copy the relevant weights into a PEFT model so it can be saved via PEFT's save_pretrained method.
Let's start by loading the state dict. If you uncomment the print statement, you'll see that for every linear layer that had a LoRA adapter, we have something like this:
from safetensors import safe_open

tensors = {}
with safe_open("qlora_output/model_state_dict.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)  # loads the full tensor given a key
        # print(k, tensors[k].dtype, tensors[k].shape)  # Uncomment to view
base_model.model.model.layers.0.mlp.down_proj.base_layer.weight torch.bfloat16 torch.Size([11272192, 1])
base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight torch.bfloat16 torch.Size([8, 11008])
base_model.model.model.layers.0.mlp.down_proj.lora_B.default.weight torch.bfloat16 torch.Size([4096, 8])
The base weights are flattened and quantized 4-bit values, which we won’t need (we’ll load the original base model later), and the lora_A and lora_B adapters are the ones we’re interested in.
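As a quick back-of-the-envelope check on that flattened shape (a sketch only; the exact packing is a bitsandbytes implementation detail): down_proj maps 11008 features to 4096, and each 16-bit element of the stored tensor holds four packed 4-bit values.

n_weights = 4096 * 11008   # down_proj: out_features x in_features
print(n_weights)           # 45088768
print(n_weights // 4)      # 11272192 -> matches torch.Size([11272192, 1]) above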
To save memory, we can delete everything but the LoRA layers:
for k in tensors:
    if 'lora' not in k: tensors[k] = None
Next, we load the base model and add a random adapter:
import torch
from transformers import LlamaForCausalLM, BitsAndBytesConfig
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
# Make sure the compute type, target modules, rank, alpha etc match!
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    use_cache=False,
    quantization_config=bnb_config
)

# Freeze
for param in model.parameters():
    param.requires_grad = False

# Add LoRA (make sure your rank (r) and alpha (lora_alpha) values match those used in training!)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=64, lora_alpha=16, lora_dropout=0.1,
    target_modules=["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]
)
model = get_peft_model(model, peft_config)
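At this point you can optionally confirm that only the adapter weights are trainable; print_trainable_parameters is a standard PEFT helper (not used in the original walkthrough):

# Should report only the LoRA adapter parameters as trainable.
model.print_trainable_parameters()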
# Check out the first few keys in the state dict:
list(model.state_dict().keys())[:10]
['base_model.model.model.embed_tokens.weight',
'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight',
'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.absmax',
'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.quant_map',
'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.quant_state.bitsandbytes__nf4',
'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight',
'base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight',
'base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight',
'base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight.absmax',
'base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight.quant_map']
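Before copying anything over, it can be reassuring to confirm that the LoRA key names in the fresh PEFT model line up with the ones we loaded from disk (a quick extra check, not part of the original walkthrough):

# The adapter key names should be identical in the saved state dict and the new PEFT model.
saved_lora_keys = {k for k in tensors if 'lora' in k}
model_lora_keys = {k for k in model.state_dict() if 'lora' in k}
print(len(saved_lora_keys), saved_lora_keys == model_lora_keys)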
Now, if all goes well, we can replace the randomly initialized LoRA layers with our trained ones:
new_sd = model.state_dict()
for k in new_sd:
    if 'lora' in k:
        new_sd[k] = tensors[k]

model.load_state_dict(new_sd)
<All keys matched successfully>
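If you want to double-check that the trained values (rather than the random initialisation) ended up in the model, you can compare one adapter weight directly (a sketch; casting to float avoids any dtype mismatch between the model parameters and the saved tensors):

k = 'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight'
print(torch.allclose(model.state_dict()[k].float().cpu(), tensors[k].float().cpu()))  # should print True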
And now, since we have a regular PEFT model, we can save using the built-in methods:
"lora_adapters") model.save_pretrained(
!ls lora_adapters
README.md adapter_config.json adapter_model.safetensors
# model.push_to_hub('your_repo_id') # If you want to share your model...
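Finally, as a sketch of the downstream usage mentioned at the start (assuming the same base model and quantization config as above), the saved adapter can be loaded back on top of a freshly loaded base model with PEFT:

# Sketch: reload the base model and attach the saved adapter for evaluation/inference.
from peft import PeftModel

base = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    use_cache=True,
    quantization_config=bnb_config,
)
inference_model = PeftModel.from_pretrained(base, "lora_adapters")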